{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 在线性回归模型中使用梯度下降法\n",
    "> 主要是上一节总结的损失函数J的表达式的实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(666)  # fixed seed so the synthetic data below is reproducible\n",
    "x = 2 * np.random.random(size=100) # 100 uniform random samples, scaled by 2 into [0, 2)\n",
    "y = x * 3. + 4. + np.random.normal(size=100) # line with slope 3 and intercept 4, plus normal noise (np.random.normal defaults)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = x.reshape(-1, 1) # one column, row count inferred from x (100 rows here)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 1)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape # feature matrix after the reshape: 100 rows, 1 column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape # targets stay 1-D: one value per sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGnFJREFUeJzt3X+sXGWdx/HPl7ZKi6wtS92Fi1VITBtZXAo3G6QbFVCLKNLFTdRIgopp3B9G0O1uCYngJhuasInsxs2arrJqJFh+2cVfC6zFmEWLubUtpUIFUZGLK1WortKVS/nuH/dcOnd6zsyZM+fHc57zfiVN7z1zZua5Z2a+85zv832eY+4uAED7HdV0AwAA5SCgA0AkCOgAEAkCOgBEgoAOAJEgoANAJIYGdDO7wcyeNLMHerZdZ2YPmdn9ZvYlM1tabTMBAMPk6aF/VtL5fdvulvRH7v4aST+QdGXJ7QIAjGhoQHf3b0l6qm/bXe7+XPLrdkknVdA2AMAIFpbwGO+XtCXrRjNbL2m9JB1zzDFnrlq1qoSnBIDu2LFjxy/cffmw/cYK6GZ2laTnJN2YtY+7b5a0WZImJyd9ampqnKcEgM4xs5/k2a9wQDezSyW9TdJ5zoIwANC4QgHdzM6X9HeSXu/uz5TbJABAEXnKFm+S9B1JK83scTO7TNInJR0r6W4z22Vmn6q4nQCAIYb20N393SmbP1NBWwAAY2CmKABEgoAOAJEoow4dQMtt3Tmt6+7cpycOHNSJSxdrw9qVWrd6oulmYUQEdKDjtu6c1pW379HBmUOSpOkDB3Xl7XskiaDeMqRcgI677s59LwTzOQdnDum6O/c11CIURUAHOu6JAwdH2o5wEdCBjjtx6eKRtiNcBHSg4zasXanFixbM27Z40QJtWLuyoRahKAZFgY6bG/ikyqX9COgAtG71BAE8AqRcACASBHQAiAQBHQAiQUAHgEgQ0AEgEgR0AIgEAR0AIkFAB4BIMLEIQJS6uMY7AR1AdLq6xjspFwDR6eoa7wR0ANHp6hrvBHQA0enqGu8EdADR6eoa7wyKAmiFUapWurrGOwEdQPCKVK3UtcZ7SOWRBHQAwRtUtVJl8BwWrEMrjySHDiB4TVStzAXr6QMH5TocrLfunH5hn9DKIwnoAILXRNVKnmAdWnkkAR1A8JqoWskTrEMrjySgAwjeutUTuvbi0zSxdLFM0sTSxbr24tMqzVPnCdahlUcyKAqgFeqqWpmzYe3KeQOe0pHBOrTySAI6AKTIG6yzvmiaKGckoAOIXtHgWvSsoKlyRnLoAKKWVn54+ZZdOv3jd80rQSxTU+WMQwO6md1gZk+a2QM9244zs7vN7OHk/2WVthIACkoLrpJ04ODMEXXlZWmqnDFPD/2zks7v27ZR0jfc/VWSvpH8DgDBGRREq+o1N1XOODSgu/u3JD3Vt/kiSZ9Lfv6cpHUltwsASjEsiFbRa26qnLFoDv0P3P1nkpT8/7KsHc1svZlNmdnU/v37Cz4dABSTFlx7VdFrbqJuXqqhysXdN0vaLEmTk5Ne9fMBQK+5IPrxL+/V08/MzLutyl5z3XXzUvEe+s/N7ARJSv5/srwmAUC51q2e0M6PvVnXv/P02nvNdSraQ79D0qWSNiX//0dpLQKAijTRa65TnrLFmyR9R9JKM3vczC7TbCB/k5k9LOlNye8AgAYN7aG7+7szbjqv5LYAAMbA1H8AGFFIl53rRUAHCgj1A43DqnqNQrvsXC8COjCikD/QXTIoYFf5GjV1fdM8WJwLGFFo15HsomHX+6zyNQrtsnO9COjAiEL+QHfFsIBd1Wu0dee0jjJLva2py871IqADIwrtOpJdNCxgV/EazZ0VHPIjJ7w3edm5XgR0YEShXUeyi4YF7Cpeo2vu2Ju6DO8Cs2BmnBLQgRE1tfASDhsWsMt+jbbunNaBgzOptz3vHsxrT5ULUEDs
U8hDl+d6n2W+RoMGU0NKtRHQAbRSnV+qgwZTQ0q1kXIBgCGyeuHLliwK6kyNgA5gbFt3TmvNpm06eeNXtWbTtsouvtyUrJz91Ree2lCL0pFyATDPqFPmuzBzNk/OPgTmKTWVVZmcnPSpqanang+IUZXryPQHZ2m2JzqoQmTNpm2aTskxL1uySDs/9uZS2tV1ZrbD3SeH7UfKBWiRYVPex1VkynzWgOHTz8xEl3oJHSkXoEXKWBiqv4d/zqrluueh/Xoi+ZJIM6jK48Sli1N76JL00Zt364otuwqfSRQ9G+nqapgEdKBF8qxRMuoqhF/Y/tjQ5x1Ua71h7UpdvmVX6m1z0+SL5NWL5Oa37pw+4mLQMeb0s5ByAQKUVTUybMp7kVUIh1l0lA2stV63ekJLFy8a+jijrnY4avpn7m/vDeZFn7utCOhAYAYF5WFT3ouuQjhQ+uKC81zz9lOPaFeaUZ5/1BUTh31ZFV1psU0lmQR0IDDD8uSD1igpugrhIDOHfGjvtr9dC0pYYnbUFROHBewif3vVg9BlI4cOBGZYUB405T1rgLJ3FcL+ssRx2tSrt11Z5Y+jTJNPa+ugxxg0OFt0pcWQr06Uhh46EJiXZuSj8/Qwi6xCeMlZK0rtWWc9z6irHY76GGl/uyQtXbyo8EqLbbuYCT10ICBbd07rt88+d8T2YQOTc8ZdhbCMnnWe56niMaqYzTnsjCc0BHQgINfduU8zh46sBn/J0QtHCmxFg1hbprhnKXsFxlHTPk0joAMByTqVP5BSilcV1no/rG1fcAR0ICChnuLnmXkZ6+zMNn3BMSgKBCTE65Vu3TmtDbfsnle6t+GW3fNK99pW3hcrAjoQkBCvV3rNHXs18/z8vP7M865r7tj7wu9FFvVC+Ui5IFixnsIPE9opftbFkXu3t628L1YEdASpCxdNiMk4uf/Yvrib/HtIuSBInMKHY9mS9IlOvduL5v5jy703/fcQ0BEkTuHDcfWFp2rRgvkzSBctsHnX0yya+8/64r58y67gF8JK03RHhJQLghRq+V4X5a3FLpL7H/QF3cY0W9MdEQI6gtS2GXqxq2qgdtCCWlLYC2GlabojQsoFQQqhfK9N62CHYtRjlrWgVq82pdmankcwVg/dzK6Q9AFJLmmPpPe5+/+V0TCgyfK9mKtsqqrCKHLMetM5WT31NqXZml4qwNyzLgs75I5mE5L+W9Kr3f2gmd0s6Wvu/tms+0xOTvrU1FSh5wPqtGbTttQAM7F0se7deG4DLSpH1mqKZZz9jHvMqmxb25nZDnefHLbfuCmXhZIWm9lCSUskPTHm4wFBaHpwqypVVmGMe8xCSLO1XeGUi7tPm9k/SnpM0kFJd7n7Xf37mdl6SeslacWKFUWfDqhV04NbVanyi6qMYxbaLNm2KdxDN7Nlki6SdLKkEyUdY2aX9O/n7pvdfdLdJ5cvX168pUCNmh7cqsqo1+kcRazHrE3GSbm8UdKP3H2/u89Iul3S2eU0C2hWrKf/owbdUapWYj1mbTJOlctjks4ysyWaTbmcJ4kRT0QjxtP/UaowilatxHbM2mScHPp9ZnarpO9Jek7STkmby2oYgGrkDbptu+I9xqxDd/erJV1dUluAI8S2El/I+o91Vl142yt9YsbUfwQr5sk9oUk71qbZGYP92l7pEzOm/iNYTa9c1yVpx9olWd9+VK2EjYCOYMU6uSdEWcfUJapWWoSUC4IV6+SeEGUd6wVmjFu0CD30iLV9tUAmqtQna9XDQ+6tvoJQ1xDQI9X0pbDKwESV+swd6wXWnzVn3KJNSLlEKpYaYiaq1Gfd6gldsWVX6m2MW7QDAT1SDCgOR437kRi3aDdSLpGqchGmGMSQkqoC4xbtRkCPVB0fzDYPulLjno5xi3Yj5RKpqi+F1fZZnKSksjFu0V4E9IhV+cFs+6AruWLEiIAekToH+ULu4eY5DhvWrky9fmXRlBQDrAgBAT0SdadAQu3h5j0OZaak2p5+
QjwI6JEoIwUySi+z7B5uWe0b5TiUlZJqW/qJs4l4EdADk/Zhk4b3JMdNgYzay6x60LVo+5pIBYWUfhoWrDmbiBsBPSBpH7YNt+yWTJo55C9sS/sAjpsCKdLLrLMaIm/7mkgFhZJ+yhOs23Y2gdFQhx6QtA/bzPP+QjCfk1YvPW7deUi9TOnIGve8V89pYmJMKJNx8tTWh/Y6o1z00Gs07HR4lA9V/77jpkBC6WVKo10956WLF2nNpm3z/uZrLz6t1hxx3emnLHmCdUivM8pHQK9JntPhQddx7Jf2ARwnBdLEIGeWQVfP6Q3qi44y/fbZ53Tg4Iykw8f02otP070bz62tvVIYk3HyBOuQXmeUj5RLTfKcDqedui86yrRowfwlTav4AIY05Tvv1XNecvTCXOmorsiT+gnpdUb56KHXJM/pcNape9q2Kj6ATfQy09JQWT3NiaWL5/W8T9741dTHDC0fXFeZYN7UTwhnE6gGAb0meXOXWR+2GD+AWWmod5w5odt2TA9NC7QhH1x3mSDButtIudQklEqIkGSloe55aH+utECdx7ToypKs6og60UOvSdrp8Dmrluu6O/fpii27Ojljb1AaKk9Ps67qknF62ZQJok4E9Br1BqmuzNgblD8uI2VSVophUDvHmYzThrQQ4kHKpSFdOBUfdlWgUNJQw9o5Ti87lL8R3UBAb0gXTsWHfWmFUkI3rJ3jXM4vlL8R3UDKpSFdOBXPW6rZdHAb1s5xJ+OE8DeiG+ihN6QLp+JlX6i6qmuYDmsnvWy0BT30hoSy/keVypxmXuUgcp520stGGxDQGxR7kCjzS6vKZV+78OWKbiCgo1JlfWlVPYgc+5cruqE1AZ3LZnVbFwaRgXG1YlB0WJ0w4teFQWRgXGMFdDNbama3mtlDZvagmb22rIb16sIkHAxGpQkw3Lgpl3+S9J/u/udm9iJJS0po0xG6MAkHw5HnBgYr3EM3s9+T9DpJn5Ekd3/W3Q+U1bBeZdczA0CMxkm5nCJpv6R/N7OdZvZpMzumfyczW29mU2Y2tX///kJPRP4Udahq4hJQl3EC+kJJZ0j6V3dfLem3kjb27+Tum9190t0nly9fXuiJyJ+iagy8Iwbj5NAfl/S4u9+X/H6rUgJ6WcifokpVTlwC6lK4h+7u/yPpp2Y2l/c4T9L3S2kVUDMG3hGDcatcPiTpxqTC5VFJ7xu/SWFhQlM3MHEJMRgroLv7LkmTJbUlOF25qlDblfGlW+ZCYkBTWjP1vwnkVZuVJ1CX9aXLAl2IAQF9APKqzckbqMv80mXgHW3XirVcmsKEpubkXe6BL13gMAL6AExoak7eQM2XLnAYAX0AJjQ1J2+g5ksXOKzzOfRhA2/kVZuRt+qEwUzgsE4HdMoSwzVKoOZLF5gVVUAftR45hrLEmCc+EaiB0UQT0Iv0tttWIdEfvM9ZtVy37ZgO6gwj5i8YIHTRDIoWuapRmyok0lYDvHH7Y7VeyWnY8rKsWAg0K5qAXqS33aYKibQvLM/Yt4ozjDzBmksFAs2KJqAX6W23qSxxlCBdxRlGnmDdthQWEJtocuhFF1dqy8Bb1mqApvk99SJnGHny3nmCNSsWAs2Kpofept52EVnpofectWKsvzlv3jvPGVCbUlhAjKLpoUvt6W0XUdUEmrylm3nOgJjkAzQrqoAeuyq+sPLmvfMG69C/VCmrRMwI6BVqMnjkfe5R8t6hB+thmBmM2EWTQ08zrG666uduqiZ7lOfuUt6bskrELtqA3vQkl6zg8dGbd1fehlECV+yDyb0oq0Tsok25ZAW1a+7YW0uwygoSh9wrP80fNXC1PZWSF2WViF20PfSs4HXg4EwtvfRBQaLq0/w2LWlQpy6ll9BN0Qb0QcGrjpxpWvDoVeVpPoErXZfSS+imaFMuG9au1OVbdqXeVkfOdC5IfPTm3TrkR666MugLZ9zqGOrBs3UlvYRuijagr1s9oY9/ea+efmbmiNvmgmnVZYVzjzXK
kgRlldYRuIDuiTblIklXX3hqZuqhriqYUU/zKa0DUFS0PXRpcOphzaZttV2taJTeMqV1AIqKOqBL2cE01MBJaR2AoqJOuQwSamkfFSoAiupsQA81cFJaB6Co1qRcyq5ICbm0jwoVAEW0IqBXtUoegRMSS+oiHq0I6HkvwhCDJoJLlwMaS+oiJq3IoYdakVK2JlaIbHpVyqZR94+YtCKgh1qRUrYmgkvXA1pXOgvohlakXLKuZ3nOquVas2lbbamCqlMTTQSXrgc06v4Rk7F76Ga2wMx2mtlXymhQmt5SPklaYKaDM4d04/bHaksV1JGaaOJMpCtnP1lCLV8Fiigj5fJhSQ+W8DgDrVs98cKHb271wv41DKtMFdSRmmgiuHQ9oFH3j5iMlXIxs5MkvVXSP0j6SCktGiAtqParKlWQ9bhpp+tFNVEbH3I9fl0oX0Usxs2hXy/pbyUdm7WDma2XtF6SVqxYMdaT5QnWVaUKsnKtptl0TFkBoYngQkAD4lA45WJmb5P0pLvvGLSfu29290l3n1y+fHnRp5M0PFhXmSrYsHalLGW7q54rII1r685prdm0TSdv/KrWbNrWmbJEoEvGyaGvkfR2M/uxpC9KOtfMvlBKqzKk5XvngmzVuc91qyeOyNnPCb0ipOu15kBXFE65uPuVkq6UJDN7g6S/cfdLSmpXqqbzvRMtLXHr0kxboMtaUYfeq858b3/d+Tmrluu2HdO5LycXiq7XmgNdUcpMUXf/pru/rYzHCkVamuK2HdN6x5kTrStx63qtOdAVrZj634SsNMUXtj8mSfrEO0/XvRvPDT6YS9SaA13RupRLXQalI6pYka/KZQWaHnsAUA8CeoasuvM5ZQ4q1rGEK7XmQPxIuWRIS1P0K2tQsesrHgIoBz30DL1piqyeelmDilShACgDPfQB1q2e0L0bz9X17zy90kFFqlAAlIGAnkPVK/JRhQKgDKRccqpyUJEqFABlIKAHoj+ozw2IEtQB5EVADwRXnwcwLnLogaB0EcC4COiBoHQRwLgI6IGgdBHAuAjogaB0EcC4WjkoWuVCVk2hdBHAuFoX0GOuBmEBLQDjaF3KhWoQAEjXuoBONQgApGtdQKcaBADStS6gUw0CAOlaNyhKNQgApGtdQJeoBgGANK1LuQAA0rWyhx7jxCIAGFfrAnrME4sAYBytS7kwsQgA0rUuoDOxCADStS6gM7EIANK1LqAzsQgA0rVuUJSJRQCQrnUBXWJiEQCkaV3KBQCQjoAOAJEgoANAJAjoABAJAjoARMLcvb4nM9sv6ScF7368pF+U2Jyy0K7R0K7R0K7RxNquV7j78mE71RrQx2FmU+4+2XQ7+tGu0dCu0dCu0XS9XaRcACASBHQAiESbAvrmphuQgXaNhnaNhnaNptPtak0OHQAwWJt66ACAAQjoABCJIAK6mZ1vZvvM7BEz25hy+4vNbEty+31m9sqe265Mtu8zs7U1t+sjZvZ9M7vfzL5hZq/oue2Qme1K/t1Rc7vea2b7e57/Az23XWpmDyf/Lq25XZ/oadMPzOxAz22VHC8zu8HMnjSzBzJuNzP756TN95vZGT23VXmshrXrPUl77jezb5vZH/fc9mMz25Mcq6ma2/UGM/tVz2v1sZ7bBr7+FbdrQ0+bHkjeT8clt1VyvMzs5WZ2j5k9aGZ7zezDKfvU+/5y90b/SVog6YeSTpH0Ikm7Jb26b5+/lPSp5Od3SdqS/PzqZP8XSzo5eZwFNbbrHElLkp//Yq5dye+/afB4vVfSJ1Pue5ykR5P/lyU/L6urXX37f0jSDTUcr9dJOkPSAxm3XyDp65JM0lmS7qv6WOVs19lzzyfpLXPtSn7/saTjGzpeb5D0lXFf/7Lb1bfvhZK2VX28JJ0g6Yzk52Ml/SDls1jr+yuEHvqfSHrE3R9192clfVHSRX37XCTpc8nPt0o6z8ws2f5Fd/+du/9I0iPJ49XSLne/x92fSX7dLumkkp57rHYNsFbS
3e7+lLs/LeluSec31K53S7qppOfO5O7fkvTUgF0ukvR5n7Vd0lIzO0HVHquh7XL3byfPK9X33spzvLKM874su111vbd+5u7fS37+X0kPSuq/UEOt768QAvqEpJ/2/P64jjwoL+zj7s9J+pWk38953yrb1esyzX4TzznazKbMbLuZrSupTaO06x3JKd6tZvbyEe9bZbuUpKZOlrStZ3NVx2uYrHZXeaxG1f/eckl3mdkOM1vfQHtea2a7zezrZnZqsi2I42VmSzQbGG/r2Vz58bLZNPBqSff13VTr+yuEKxZZyrb+WsqsffLct6jcj21ml0ialPT6ns0r3P0JMztF0jYz2+PuP6ypXV+WdJO7/87MPqjZs5tzc963ynbNeZekW939UM+2qo7XME28t3Izs3M0G9D/tGfzmuRYvUzS3Wb2UNKDrcP3NLuuyG/M7AJJWyW9SoEcL82mW+51997efKXHy8xeotkvkMvd/df9N6fcpbL3Vwg99Mclvbzn95MkPZG1j5ktlPRSzZ5+5blvle2Smb1R0lWS3u7uv5vb7u5PJP8/Kumbmv32rqVd7v7Lnrb8m6Qz8963ynb1eJf6TokrPF7DZLW7ymOVi5m9RtKnJV3k7r+c295zrJ6U9CWVl2Ycyt1/7e6/SX7+mqRFZna8AjheiUHvrdKPl5kt0mwwv9Hdb0/Zpd73V9kDBQUGFhZqdkDgZB0eTDm1b5+/0vxB0ZuTn0/V/EHRR1XeoGiedq3W7EDQq/q2L5P04uTn4yU9rJIGiHK264Sen/9M0nY/PBDzo6R9y5Kfj6urXcl+KzU7SGV1HK/kMV+p7EG+t2r+oNV3qz5WOdu1QrNjQmf3bT9G0rE9P39b0vk1tusP5147zQbGx5Jjl+v1r6pdye1zHb1j6jheyd/9eUnXD9in1vdXaQd7zANzgWZHiH8o6apk299rttcrSUdLuiV5g39X0ik9970qud8+SW+puV3/JennknYl/+5Itp8taU/ypt4j6bKa23WtpL3J898jaVXPfd+fHMdHJL2vznYlv18jaVPf/So7Xprtrf1M0oxme0WXSfqgpA8mt5ukf0navEfSZE3Hali7Pi3p6Z731lSy/ZTkOO1OXuOram7XX/e8t7ar5wsn7fWvq13JPu/VbJFE7/0qO16aTYO5pPt7XqcLmnx/MfUfACIRQg4dAFACAjoARIKADgCRIKADQCQI6AAQCQI6AESCgA4Akfh/3IBy0SKTn8QAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(x, y) # scatter plot of the synthetic samples before any fitting\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用梯度下降法训练\n",
    "> 这里的m是指样本的数量；n是指变量的个数(x1,x2,x3等)\n",
    "\n",
    "![梯度下降法.png](images/多元线性回归的损失函数J以及其导数dJ的表达式.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def J(theta, X_b, y):\n",
    "    try:\n",
    "        return np.sum((y - X_b.dot(theta))**2) / len(X_b)\n",
    "    except:\n",
    "        return float(\"inf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dJ(theta, X_b, y):\n",
    "    result = np.empty(len(theta))\n",
    "    result[0] = np.sum(X_b.dot(theta) - y) # dJ的第一项Xi是1，形式和其他的n项不一样\n",
    "    for i in range(1, len(theta)): # 处理剩下的m项，形式都是一样地了\n",
    "        result[i] = (X_b.dot(theta) - y).dot(X_b[:, i])\n",
    "    return result * 2 / len(X_b) # 结果乘以2/m，m是样本数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_descent(X_b, y, initial_theta, eta, epsilon=1e-8, max_iters=10000):\n",
    "    \"\"\"Minimise the cost J(theta, X_b, y) by batch gradient descent.\n",
    "\n",
    "    Args:\n",
    "        X_b: design matrix, passed through to J and dJ.\n",
    "        y: target values, passed through to J and dJ.\n",
    "        initial_theta: starting parameter vector; each step builds a new\n",
    "            array, so the argument itself is not updated in place.\n",
    "        eta: learning rate (factor applied to the gradient at each step).\n",
    "        epsilon: stop once the cost changes by less than this between two\n",
    "            consecutive steps.\n",
    "        max_iters: upper bound on the number of update steps.\n",
    "\n",
    "    Returns:\n",
    "        The parameter vector after the last update step.\n",
    "    \"\"\"\n",
    "    theta = initial_theta\n",
    "    # Keep the cost of the current theta so that each iteration evaluates J\n",
    "    # once (for the new theta) instead of twice.\n",
    "    cost = J(theta, X_b, y)\n",
    "    for _ in range(max_iters):\n",
    "        theta = theta - eta * dJ(theta, X_b, y)\n",
    "        last_cost, cost = cost, J(theta, X_b, y)\n",
    "\n",
    "        # A change in cost below epsilon means the gradient is close to zero,\n",
    "        # i.e. the loss is (locally) minimal and the fit cannot improve further.\n",
    "        if abs(cost - last_cost) < epsilon:\n",
    "            break\n",
    "    return theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_b = np.hstack([np.ones((len(X), 1)), X.reshape(-1, 1)]) # prepend a column of ones to X: one row per sample (m), 1 + n columns\n",
    "initial_theta = np.zeros(X_b.shape[1]) # start from all-zero parameters, one per column of X_b (n + 1 entries)\n",
    "eta = 0.01 # learning rate passed to gradient_descent\n",
    "theta = gradient_descent(X_b, y, initial_theta, eta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 4.02145786,  3.00706277])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "theta # fitted parameters; column 0 of X_b is all ones, so theta[0] is the intercept and theta[1] the slope"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试封装的线性回归算法(playML)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from playML.LinearRegression import LinearRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "lin_reg = LinearRegression() # create the linear-regression estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearRegression()"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin_reg.fit_gd(X, y) # gd = gradient descent; the method is defined in playML (not visible in this notebook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 3.00706277])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin_reg.coef_ # slope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.021457858204859"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin_reg.intercept_ # intercept"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
